from matplotlib import pyplot as plt

# Schedule configuration.
# warmup_steps equals total_steps, so over steps 1..total_steps the linear
# warmup term is always the smaller one in lr(): the evaluated curve is the
# warmup ramp only, peaking at the final step.
d_model = 512            # model width; the schedule is scaled by d_model ** -0.5
total_steps = 28_800     # number of steps evaluated below (1..total_steps)
warmup_steps = 28_800    # length of the linear warmup phase

def lr(step, *, model_dim=None, warmup=None):
    """Return the Noam (Transformer) learning rate for a training step.

    Computes ``model_dim**-0.5 * min(step**-0.5, step * warmup**-1.5)``.
    The rate grows linearly for the first ``warmup`` steps, reaches its peak
    at ``step == warmup``, and then decays proportionally to ``step**-0.5``.

    Args:
        step: 1-based training step. ``0`` is treated as ``1``, because
            ``0 ** -0.5`` would raise ``ZeroDivisionError``.
        model_dim: Model width. Defaults to the module-level ``d_model``,
            which is read at call time.
        warmup: Number of warmup steps. Defaults to the module-level
            ``warmup_steps``, which is read at call time.

    Returns:
        The learning rate as a float.

    Raises:
        ValueError: If ``step`` is negative. Raising a negative number to
            ``-0.5`` yields a complex number in Python, which would otherwise
            surface later as an opaque ``TypeError`` from ``min``.
    """
    if step < 0:
        raise ValueError(f"step must be non-negative, got {step}")
    if model_dim is None:
        model_dim = d_model
    if warmup is None:
        warmup = warmup_steps
    # Step 0 has no defined rate (0 ** -0.5); use the step-1 value instead.
    step = max(step, 1)
    return (model_dim ** -0.5) * min(step ** -0.5, step * (warmup ** -1.5))


# Evaluate the schedule at every step 1..total_steps (inclusive) and report
# the largest rate reached.
x = range(1, total_steps + 1)
y1 = list(map(lr, x))
print('max = ', max(y1))

